"""Migrate onboarding state [ff538a321a92].

Revision ID: ff538a321a92
Revises: 0.80.2
Create Date: 2025-04-11 09:30:03.324310

"""

import json

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own ID; `down_revision` is the revision it is
# applied on top of (the "Revises:" entry in the module docstring). No branch
# labels or cross-branch dependencies are declared.
revision = "ff538a321a92"
down_revision = "0.80.2"
branch_labels = None
depends_on = None


def _dedupe_preserving_order(steps: list) -> list:
    """Return ``steps`` without duplicates, keeping first-seen order.

    ``list(set(steps))`` would also drop duplicates, but its ordering depends
    on string hash randomization, so the persisted JSON would differ between
    otherwise identical runs.
    """
    return list(dict.fromkeys(steps))


def _remote_artifact_store_conditions(
    stack_table: "sa.Table",
    stack_composition_table: "sa.Table",
    stack_component_table: "sa.Table",
) -> tuple:
    """Return WHERE clauses linking a stack to a non-local artifact store.

    The clauses join ``stack`` -> ``stack_composition`` -> ``stack_component``
    and keep only components whose ``type`` is ``"artifact_store"`` and whose
    ``flavor`` is not ``"local"``.
    """
    return (
        stack_composition_table.c.stack_id == stack_table.c.id,
        stack_composition_table.c.component_id == stack_component_table.c.id,
        stack_component_table.c.flavor != "local",
        stack_component_table.c.type == "artifact_store",
    )


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Changes ``server_settings.onboarding_state`` from VARCHAR to TEXT, then
    back-fills onboarding steps derived from existing data:

    - ``stack_with_remote_artifact_store_created`` if any stack contains an
      artifact store component whose flavor is not ``local``;
    - ``pipeline_run_with_remote_artifact_store`` and
      ``production_setup_completed`` if any pipeline run's deployment points
      at such a stack.

    Nothing is back-filled when no onboarding state is stored yet (NULL or
    empty string).
    """
    with op.batch_alter_table("server_settings", schema=None) as batch_op:
        batch_op.alter_column(
            "onboarding_state",
            existing_type=sa.VARCHAR(),
            type_=sa.TEXT(),
            existing_nullable=True,
        )

    connection = op.get_bind()

    settings_meta = sa.MetaData()
    settings_meta.reflect(only=("server_settings",), bind=connection)
    server_settings_table = sa.Table("server_settings", settings_meta)

    # scalar_one_or_none() expects at most one server_settings row and raises
    # if there are several.
    existing_onboarding_state = connection.execute(
        sa.select(server_settings_table.c.onboarding_state)
    ).scalar_one_or_none()

    if not existing_onboarding_state:
        return

    # NOTE(review): assumes the stored JSON is a list of step names, since it
    # is appended to below.
    state = json.loads(existing_onboarding_state)

    meta = sa.MetaData()
    meta.reflect(
        only=(
            "pipeline_run",
            "stack_component",
            "stack",
            "stack_composition",
            "pipeline_deployment",
        ),
        bind=connection,
    )

    pipeline_run_table = sa.Table("pipeline_run", meta)
    stack_component_table = sa.Table("stack_component", meta)
    stack_table = sa.Table("stack", meta)
    stack_composition_table = sa.Table("stack_composition", meta)
    pipeline_deployment_table = sa.Table("pipeline_deployment", meta)

    remote_store_conditions = _remote_artifact_store_conditions(
        stack_table, stack_composition_table, stack_component_table
    )

    # The implicit join may count a stack more than once; only "any match"
    # matters here, so the exact count is irrelevant.
    stack_with_remote_artifact_store_count = connection.execute(
        sa.select(sa.func.count(stack_table.c.id)).where(
            *remote_store_conditions
        )
    ).scalar()
    if stack_with_remote_artifact_store_count:
        state.append("stack_with_remote_artifact_store_created")

    # Runs are linked to stacks through their deployment; runs without a
    # deployment never match.
    pipeline_run_with_remote_artifact_store_count = connection.execute(
        sa.select(sa.func.count(pipeline_run_table.c.id))
        .where(
            pipeline_run_table.c.deployment_id
            == pipeline_deployment_table.c.id
        )
        .where(pipeline_deployment_table.c.stack_id == stack_table.c.id)
        .where(*remote_store_conditions)
    ).scalar()
    if pipeline_run_with_remote_artifact_store_count:
        state.append("pipeline_run_with_remote_artifact_store")
        state.append("production_setup_completed")

    # Steps may already have been present; drop duplicates deterministically.
    state = _dedupe_preserving_order(state)

    connection.execute(
        sa.update(server_settings_table).values(
            onboarding_state=json.dumps(state)
        )
    )


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Changes ``server_settings.onboarding_state`` back from TEXT to VARCHAR.
    Stored data is not modified, so onboarding steps added by ``upgrade`` remain.
    """
    with op.batch_alter_table("server_settings", schema=None) as batch_op:
        batch_op.alter_column(
            "onboarding_state",
            existing_type=sa.TEXT(),
            # MySQL cannot compile a length-less VARCHAR ("VARCHAR requires a
            # length on dialect mysql"), so an explicit length is required.
            # NOTE(review): 255 presumably matches the pre-upgrade column on
            # MySQL; confirm against the migration that created it.
            type_=sa.VARCHAR(length=255),
            existing_nullable=True,
        )
